import tensorflow as tf
import numpy as np

np.random.seed(1234)
import os
import time
import datetime
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from builddata_ecir import *
from model_R_MeN_SP_CNN import RMeN_SP
from scipy.stats import rankdata

# Parameters
# ==================================================
def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a classic trap: ``bool("False")`` is True
    because any non-empty string is truthy. This helper maps the usual
    textual spellings to real booleans; raising ValueError makes argparse
    report an invalid value for anything else. Defaults (already bool)
    pass through untouched, so the interface is backward-compatible.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise ValueError("Boolean value expected, got %r" % value)

parser = ArgumentParser("RMeN_SP", formatter_class=ArgumentDefaultsHelpFormatter, conflict_handler='resolve')

parser.add_argument("--data", default="./data/", help="Data sources.")
parser.add_argument("--run_folder", default="../", help="Data sources.")
parser.add_argument("--name", default="SEARCH17", help="Name of the dataset.")
parser.add_argument("--embedding_dim", default=200, type=int, help="Dimensionality of character embedding")
parser.add_argument("--learning_rate", default=0.0001, type=float, help="Learning rate")
parser.add_argument("--batch_size", default=16, type=int, help="Batch Size")
parser.add_argument("--num_epochs", default=50, type=int, help="Number of training epochs")
parser.add_argument("--saveStep", default=1, type=int, help="")
parser.add_argument("--neg_ratio", default=1.0, type=float, help="Number of negative triples generated by positive")
# BUG FIX: was type=bool, under which "--allow_soft_placement False" parsed as True.
parser.add_argument("--allow_soft_placement", default=True, type=_str2bool, help="Allow device soft device placement")
parser.add_argument("--log_device_placement", default=False, type=_str2bool, help="Log placement of ops on devices")
parser.add_argument("--model_name", default='SEARCH17', help="")
parser.add_argument("--dropout_keep_prob", default=1.0, type=float, help="Dropout keep probability")
parser.add_argument("--num_heads", default=2, type=int, help="Number of attention heads. 1 2 4")
parser.add_argument("--memory_slots", default=1, type=int, help="Number of memory slots. 1 2 4")
parser.add_argument("--head_size", default=128, type=int, help="")
parser.add_argument("--gate_style", default='memory', help="unit,memory")
parser.add_argument("--attention_mlp_layers", default=2, type=int, help="2 3 4")
parser.add_argument("--use_pos", default=1, type=int, help="1 when using positional embeddings. Otherwise.")
parser.add_argument("--num_filters", default=20, type=int, help="Number of filters per filter size")

args = parser.parse_args()
print(args)

def computeMRR(lstRanks):
    """Return the mean reciprocal rank of a list of 1-based integer ranks."""
    return sum(1.0 / rank for rank in lstRanks) / len(lstRanks)

def computeP1(lstRanks):
    """Return Precision@1: the fraction of entries whose rank is exactly 1."""
    hits = sum(1.0 for rank in lstRanks if rank == 1)
    return hits / len(lstRanks)

# Load data
print("Loading data...")

# build_data_ecir() (from builddata_ecir) returns the train/valid/test triples,
# the candidate groups used for ranking, their label arrays, and the
# entity vocabularies (query/user/doc <-> index maps).
# NOTE(review): presumably it reads the dataset location from its own
# defaults rather than from args — verify against builddata_ecir.
train_triples, train_rank_triples, train_val_triples, valid_triples, valid_rank_triples, valid_val_triples, \
            test_triples, test_rank_triples, test_val_triples, query_indexes, user_indexes, doc_indexes, \
            indexes_query, indexes_user, indexes_doc = build_data_ecir()
data_size = len(train_triples)
# Batch loader: called as train_batch() in the training loop to produce
# (x_batch, y_batch) pairs.
train_batch = Batch_Loader_ecir(train_triples, train_val_triples, batch_size=args.batch_size)

# The pre-trained vectors below are 200-dimensional ('*.200.init' files), so
# the embedding size must be a multiple of 200.
# NOTE(review): assert is stripped under `python -O`; an explicit check with
# a raised ValueError would be more robust.
assert args.embedding_dim % 200 == 0

# Pre-trained 200-d embeddings for queries, users and documents.
pretrained_query = init_dataset_ecir(args.data + args.name + '/query2vec.200.init')
pretrained_user = init_dataset_ecir(args.data + args.name + '/user2vec.200.init')
pretrained_doc = init_dataset_ecir(args.data + args.name + '/doc2vec.200.init')

print("Using pre-trained initialization.")

# Align the pre-trained vectors with the vocabulary index order so that
# row i of each matrix is the embedding of entity index i.
lstEmbedQuery = assignEmbeddings(pretrained_query, query_indexes)
lstEmbedUser = assignEmbeddings(pretrained_user, user_indexes)
lstEmbedDoc = assignEmbeddings(pretrained_doc, doc_indexes)

# float32 matrices ready to be handed to TensorFlow as initial values.
lstEmbedQuery = np.array(lstEmbedQuery, dtype=np.float32)
lstEmbedUser = np.array(lstEmbedUser, dtype=np.float32)
lstEmbedDoc = np.array(lstEmbedDoc, dtype=np.float32)

print("Loading data... finished!")

# Training
# ==================================================
with tf.Graph().as_default():
    tf.compat.v1.set_random_seed(1234)  # reproducible graph-level randomness
    session_conf = tf.compat.v1.ConfigProto(allow_soft_placement=args.allow_soft_placement,
                                  log_device_placement=args.log_device_placement)
    session_conf.gpu_options.allow_growth = True  # claim GPU memory on demand
    sess = tf.compat.v1.Session(config=session_conf)
    with sess.as_default():
        global_step = tf.Variable(0, name="global_step", trainable=False)
        # Build the R-MeN model. The batch dimension is 20 * batch_size:
        # evaluation scores triples in chunks of this size (see the
        # 20-multiple padding in test_prediction) — presumably each query
        # contributes a group of 20 candidates; TODO confirm.
        relkb = RMeN_SP(
            batch_size=20*args.batch_size,
            initialization=[lstEmbedQuery, lstEmbedUser, lstEmbedDoc],
            embedding_size=args.embedding_dim,
            num_heads=args.num_heads,
            mem_slots=args.memory_slots,
            head_size=args.head_size,
            attention_mlp_layers=args.attention_mlp_layers,
            use_pos=args.use_pos,
        )

        # Optimizer: Adam on the model loss; global_step increments per update.
        optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=args.learning_rate)
        grads_and_vars = optimizer.compute_gradients(relkb.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Output directory for results and checkpoints.
        out_dir = os.path.abspath(os.path.join(args.run_folder, "runs_RMeN_SP_Drop1", args.model_name))
        print("Writing to {}\n".format(out_dir))

        # Checkpoint directory. Tensorflow assumes this directory already exists so we need to create it
        checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
        checkpoint_prefix = os.path.join(checkpoint_dir, "model")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Initialize all variables.
        # BUG FIX: the file consistently uses the tf.compat.v1 API;
        # tf.global_variables_initializer does not exist under TF2, so the
        # compat.v1 symbol is required here as well.
        sess.run(tf.compat.v1.global_variables_initializer())
        def train_step(x_batch, y_batch):
            """Run one optimization step on a training batch; return the batch loss."""
            feed = {
                relkb.input_x: x_batch,
                relkb.input_y: y_batch,
                relkb.dropout_keep_prob: args.dropout_keep_prob,
            }
            _, _, batch_loss = sess.run([train_op, global_step, relkb.loss], feed)
            return batch_loss

        # Scoring function used at evaluation time (dropout disabled).
        def predict(x_batch, y_batch):
            """Return the model's prediction scores for the given batch."""
            feed = {
                relkb.input_x: x_batch,
                relkb.input_y: y_batch,
                relkb.dropout_keep_prob: 1.0,
            }
            return sess.run([relkb.predictions], feed)

        def test_prediction(x_batch, y_batch, lstOriginalRank):
            """Score all evaluation triples and rank each group's first candidate.

            x_batch / y_batch are lists of per-query candidate arrays. They are
            flattened, padded up to a multiple of the model's fixed batch size
            (20 * args.batch_size) by repeating the last row, scored chunk by
            chunk with predict(), then re-split according to lstOriginalRank.
            For each group the 1-based rank (scipy rankdata, ascending scores)
            of the first candidate — presumably the ground-truth triple — is
            returned.
            """
            new_x_batch = np.concatenate(x_batch)
            new_y_batch = np.concatenate(y_batch, axis=0)

            # PERF FIX: the original padded one row at a time with np.append,
            # reallocating the whole array per row (O(n^2)). Compute the pad
            # length once and concatenate a single repeated block instead;
            # the resulting arrays are identical.
            chunk = 20 * args.batch_size
            remainder = len(new_x_batch) % chunk
            if remainder != 0:
                pad = chunk - remainder
                new_x_batch = np.concatenate(
                    [new_x_batch, np.repeat(new_x_batch[-1:], pad, axis=0)], axis=0)
                new_y_batch = np.concatenate(
                    [new_y_batch, np.repeat(new_y_batch[-1:], pad, axis=0)], axis=0)

            # Score fixed-size chunks; np.append flattens everything into one
            # 1-D score vector.
            results = []
            listIndexes = range(0, len(new_x_batch), chunk)
            for tmpIndex in range(len(listIndexes) - 1):
                results = np.append(results,
                                    predict(new_x_batch[listIndexes[tmpIndex]:listIndexes[tmpIndex + 1]],
                                            new_y_batch[listIndexes[tmpIndex]:listIndexes[tmpIndex + 1]]))
            results = np.append(results,
                                predict(new_x_batch[listIndexes[-1]:], new_y_batch[listIndexes[-1]:]))

            # Re-split the flat scores back into the original candidate groups
            # and record the rank of each group's first candidate.
            lstresults = []
            _start = 0
            for tmp in lstOriginalRank:
                _end = _start + len(tmp)
                group_scores = np.reshape(results[_start:_end], -1)
                lstresults.append(rankdata(group_scores)[0])
                _start = _end

            return lstresults


        # Result logs: `wri` tracks validation MRR, `wrip1` validation P@1.
        # NOTE(review): '.cls.' + '.txt' yields a double dot ('model.cls..txt');
        # kept as-is since downstream tooling may expect that exact name.
        # NOTE(review): handles are closed manually below; a `with` block would
        # be safer if an exception interrupts training.
        wri = open(checkpoint_prefix + '.cls.' + '.txt', 'w')

        wrip1 = open(checkpoint_prefix + '.cls.P1' + '.txt', 'w')

        lstvalid_mrr = []  # validation MRR per epoch
        lstvalid_p1 = []   # validation P@1 per epoch
        lsttest_mrr = []   # [test MRR, test P@1] per epoch
        num_batches_per_epoch = int((data_size - 1) / (args.batch_size)) + 1
        for epoch in range(args.num_epochs):
            # One pass over the training data.
            for batch_num in range(num_batches_per_epoch):
                x_batch, y_batch = train_batch()
                train_step(x_batch, y_batch)
                # current_step is fetched but not used afterwards.
                current_step = tf.compat.v1.train.global_step(sess, global_step)

            # Evaluate on validation and test sets after every epoch.
            valid_results = test_prediction(valid_triples, valid_val_triples, valid_rank_triples)
            test_results = test_prediction(test_triples, test_val_triples, test_rank_triples)
            valid_mrr = computeMRR(valid_results)
            valid_p1 = computeP1(valid_results)
            test_mrr = computeMRR(test_results)
            test_p1 = computeP1(test_results)
            lstvalid_mrr.append(valid_mrr)
            lstvalid_p1.append(valid_p1)
            lsttest_mrr.append([test_mrr, test_p1])

            wri.write("epoch " + str(epoch) + ": " + str(valid_mrr) + " : " + str(test_mrr)  + " : " + str(test_p1) + "\n")

            wrip1.write("epoch " + str(epoch) + ": " + str(valid_p1) + " : " + str(test_mrr)  + " : " + str(test_p1) + "\n")

        # Model selection by best validation MRR: report test scores at that epoch.
        index_valid_max = np.argmax(lstvalid_mrr)
        wri.write("\n--------------------------\n")
        wri.write("\nBest mrr in valid at epoch " + str(index_valid_max) + ": " + str(lstvalid_mrr[index_valid_max]) + "\n")
        wri.write("\nMRR and P1 in test: " + str(lsttest_mrr[index_valid_max][0]) + " " + str(lsttest_mrr[index_valid_max][1]) + "\n")
        wri.close()

        # Model selection by best validation P@1.
        # NOTE(review): the message below says "Best mrr" but the value written
        # is validation P@1 — the label looks wrong; confirm before relying on
        # these logs (runtime strings left untouched here).
        index_valid_max_p1 = np.argmax(lstvalid_p1)
        wrip1.write("\n--------------------------\n")
        wrip1.write("\nBest mrr in valid at epoch " + str(index_valid_max_p1) + ": " + str(lstvalid_p1[index_valid_max_p1]) + "\n")
        wrip1.write("\nMRR and P1 in test: " + str(lsttest_mrr[index_valid_max_p1][0]) + " " + str(lsttest_mrr[index_valid_max_p1][1]) + "\n")
        wrip1.close()


